#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys,os
# sys.path.append(r"/home/yh579/GAFM/exp5")
import torch
from scipy.stats import pearsonr
from torchvision.utils import save_image
import torch.nn as nn
from torchvision import transforms, datasets
from torch import optim as optim
from matplotlib import pyplot as plt
from sklearn.metrics import roc_auc_score
import math
import numpy as np
from tqdm import trange
from torch.autograd import Variable
from sklearn.datasets import load_boston
from sklearn import datasets
import random
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import accuracy_score
from  sklearn.model_selection  import  train_test_split
from sklearn import linear_model
from sklearn.metrics import mean_squared_error
import pandas as pd
from torch.optim import lr_scheduler
from sklearn import metrics
from tqdm import trange
import matplotlib.pyplot as plt
import seaborn as sns
import time
from sklearn.datasets import load_breast_cancer
from xgboost.sklearn import XGBClassifier,XGBRegressor
import seaborn as sns
from Marvell import KL_gradient_perturb_function_creator
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('Device is ', device)

##Model

class Discriminator(nn.Module):
    """Critic made of a single affine unit followed by a LeakyReLU.

    The LeakyReLU negative slope is read from
    ``training_parameters['discriminator_ReLU']`` when the module is built.
    """

    def __init__(self):
        super().__init__()
        affine = nn.Linear(1, 1)
        activation = nn.LeakyReLU(training_parameters['discriminator_ReLU'])
        # Registered as ``net`` so the parameter names stay ``net.0.weight`` /
        # ``net.0.bias``.
        self.net = nn.Sequential(affine, activation)

    def forward(self, x):
        """Score ``x`` and return the scores flattened to 1-D, moved to ``device``."""
        scores = self.net(x)
        flat = scores.view(-1)
        return flat.to(device)

class Decoder(nn.Module):
    """Network mapping generator outputs back toward the labels.

    ``forward`` applies only ``hidden_layer1`` (Linear(input_dim, 1) + ReLU).
    ``hidden_layer2`` is still constructed and registered. It therefore
    appears in ``parameters()`` after ``hidden_layer1``'s weight and bias and
    consumes RNG draws at initialisation, but it is never applied.
    """

    def __init__(self, input_dim=1, output_dim=1):
        super().__init__()
        first = [nn.Linear(input_dim, 1), nn.ReLU()]
        self.hidden_layer1 = nn.Sequential(*first)
        # NOTE(review): in_features=3 does not match hidden_layer1's single
        # output, so this layer could not simply be chained after it.
        second = [nn.Linear(3, output_dim), nn.ReLU()]
        self.hidden_layer2 = nn.Sequential(*second)

    def forward(self, x):
        """Return ``hidden_layer1(x)`` moved to ``device``."""
        hidden = self.hidden_layer1(x)
        return hidden.to(device)


class Generator_GAFM():
    """Hand-rolled GAFM generator update (adversarial + reconstruction terms).

    The generator owns no weights. The array handed to ``train`` is treated as
    the parameter ``w1`` and stepped directly, using gradients computed in
    NumPy from the current discriminator and decoder parameters.
    """

    def __init__(self):
        super(Generator_GAFM, self).__init__()
        self.lr = training_parameters['learning_rate_generator']

    def forward(self, y):
        """Store ``y`` as ``self.w1`` and return ``M_fun(ones(p_dim) @ y)``."""
        self.w1 = y
        s1 = M_fun(np.dot(np.ones(p_dim), self.w1))
        return s1

    def train(self, y, label, discriminator, decoder, it):
        """Run ``it`` gradient steps starting from ``y``.

        Args:
            y: starting values. A float copy is made, so the caller's array
                (and any alias of it) is left unchanged.
            label: object with ``.values`` holding the labels (e.g. a DataFrame).
            discriminator, decoder: modules whose first two parameters (the
                first Linear layer's weight and bias) are read, not updated.
            it: number of steps.

        Returns:
            ``(w1, gradients)``: the updated array and the sum of every step
            applied (``lr * grad`` per step; ``0`` when ``it == 0``).
        """
        # Private copy: the in-place ``-=`` below would otherwise write through
        # to the caller's buffer.
        y = np.array(y, dtype=float)
        self.w1 = y
        gradients = 0

        # Loop-invariant: labels as a flat list of scalars.
        label_t = torch.tensor(label.values).float()
        label_list = list(pd.DataFrame(label_t.detach().numpy()).iloc[:, 0])

        # Neither network is stepped inside this method, so read their
        # first Linear layer's weight/bias once.
        disc_params = [p.detach().numpy()[0] for p in discriminator.parameters()]
        dec_params = [p.detach().numpy()[0] for p in decoder.parameters()]
        w = disc_params[0][0]
        b = disc_params[1]
        w_C = dec_params[0][0]
        b_C = dec_params[1]
        gamma = training_parameters['a']

        for _ in range(it):
            l1_true = self.forward(y)
            l1_true = torch.tensor(pd.DataFrame(l1_true).values).float()
            # Slope of the generator activation at the current pre-activation.
            gen_der = M_fun_der(np.dot(np.ones(p_dim), y))

            # Adversarial term, chained through the critic's LeakyReLU.
            l1_err_1 = -(discriminator(l1_true).detach().numpy())
            l1_err_2 = np.multiply(l1_err_1, w * M_fun_der_D(w * l1_true + b).T[0])
            l1_err_3 = np.multiply(l1_err_2, gen_der)

            # Reconstruction term: decoder output versus labels.
            # NOTE(review): Decoder uses nn.ReLU, but this slope comes from
            # M_fun_der (training_parameters['generator_ReLU']) — confirm.
            l1_err_recon_1 = -(label_list - decoder(l1_true).detach().numpy().T)[0]
            l1_err_recon_2 = np.multiply(l1_err_recon_1,
                                         w_C * M_fun_der(w_C * l1_true + b_C).T[0])
            l1_err_recon_3 = np.multiply(l1_err_recon_2, gen_der)

            step = self.lr * (l1_err_3 + gamma * l1_err_recon_3)
            self.w1 -= step
            y = self.w1
            gradients += step

        return self.w1, gradients

class Generator_Vanilla():
    """Hand-rolled generator update for the vanilla baseline.

    The generator owns no weights. The array handed to ``train`` is treated as
    the parameter ``w1`` and stepped directly toward the labels.
    """

    def __init__(self):
        super(Generator_Vanilla, self).__init__()
        self.lr = training_parameters['learning_rate_generator']

    def forward(self, y):
        """Store ``y`` as ``self.w1`` and return ``M_fun(ones(p_dim) @ y)``."""
        self.w1 = y
        s1 = M_fun(np.dot(np.ones(p_dim), self.w1))
        return s1

    def train(self, y, label, it):
        """Run ``it`` gradient steps starting from ``y``.

        Args:
            y: starting values. A float copy is made, so the caller's array
                (and any alias of it) is left unchanged.
            label: labels; the first column of ``pd.DataFrame(label)`` is used.
            it: number of steps.

        Returns:
            ``(w1, gradients)``: the updated array and the sum of every step
            applied (``lr * grad`` per step; ``0`` when ``it == 0``).
        """
        # Private copy: the in-place ``-=`` below would otherwise write through
        # to the caller's buffer.
        y = np.array(y, dtype=float)
        self.w1 = y
        gradients = 0
        target = list(pd.DataFrame(label).iloc[:, 0])  # loop-invariant

        for _ in range(it):
            l1_true = self.forward(y)
            # Residual s - label, i.e. the gradient of 0.5 * (s - label)**2 in s.
            l1_err_1 = -(target - l1_true)
            l1_err_3 = np.multiply(l1_err_1, M_fun_der(np.dot(np.ones(p_dim), y)))
            step = self.lr * l1_err_3
            self.w1 -= step
            y = self.w1
            gradients += step
        return self.w1, gradients

# Helper functions: activations and their derivatives, sigmoid, noise and label generation
def M_fun_der(x):
    """Slope of ``M_fun``: 1 where ``x > 0``, otherwise the generator leak."""
    leak = training_parameters['generator_ReLU']
    return np.where(x > 0, 1, leak)


def M_fun(x):
    """Generator activation ``max(leak * x, x)`` (a leaky ReLU for leak in [0, 1])."""
    leak = training_parameters['generator_ReLU']
    return np.maximum(leak * x, x)


def M_fun_der_D(x):
    """Slope of ``M_fun_D``: 1 where ``x > 0``, otherwise the discriminator leak."""
    leak = training_parameters['discriminator_ReLU']
    return np.where(x > 0, 1, leak)


def M_fun_D(x):
    """Discriminator activation ``max(leak * x, x)``."""
    leak = training_parameters['discriminator_ReLU']
    return np.maximum(leak * x, x)
def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``, element-wise for arrays."""
    return 1.0 / (1.0 + np.exp(-x))


def sigmoid_der(x):
    """Derivative of the logistic function: ``s * (1 - s)`` with ``s = sigmoid(x)``."""
    s = sigmoid(x)
    return s * (1.0 - s)


def generateNoise(a,b,N):
    """Draw ``N`` samples from Uniform[a, b) with NumPy's global RNG, as a 1-D array."""
    return np.ravel(np.random.uniform(a, b, N))


def generateLabel(input):
    """Label a random half of ``input``'s rows 1 and the rest 0.

    Args:
        input: DataFrame whose index identifies the rows to label.

    Returns:
        DataFrame with a single ``'label'`` column of 0/1 ints, indexed like
        ``input``. Exactly ``int(len(input) * 0.5)`` rows get label 1, chosen
        with Python's ``random`` module.
    """
    # Membership is tested against ``input``'s own index. The previous version
    # used the module-level ``Y_train.index``, which is wrong for any other input.
    # A set makes each lookup O(1) instead of O(n).
    chosen = set(random.sample(list(input.index), int(input.shape[0] * 0.5)))
    res = pd.DataFrame([1 if idx in chosen else 0 for idx in input.index])
    res.index = input.index
    res.columns = ['label']
    return res




def addnoise(input,variance=1):
    """Return ``input``'s first column plus i.i.d. Gaussian noise as a ``'label'`` DataFrame.

    NOTE(review): ``variance`` is passed to ``np.random.normal`` as ``scale``,
    which NumPy defines as the standard deviation, so the noise variance is
    ``variance ** 2`` — confirm which was intended.
    """
    n_rows = input.shape[0]
    noisy = np.random.normal(0, variance, n_rows) + np.array(input.iloc[:, 0])
    out = pd.DataFrame(noisy)
    out.index = input.index
    out.columns = ['label']
    return out



##Attack
def t(Y_fake, AUC):
    """Min-max rescale ``Y_fake`` (flattened to 1-D) into roughly [0, 1].

    ``AUC`` is accepted but unused. The 1e-16 offset in numerator and
    denominator avoids division by zero for a constant input, which then maps
    to all ones.
    """
    flat = Y_fake.reshape(-1)
    hi = max(flat)
    lo = min(flat)
    return (flat - lo + 1e-16) / (hi - lo + 1e-16)

def MA(g,label):
    """Predict 0/1 labels from per-sample scalars ``g`` by nearest class mean.

    The mean of ``g`` over samples labelled 1 and the mean over the remaining
    samples are computed. The one with the larger magnitude is taken as the
    centre for label 1. Each ``g[i]`` is then assigned 1 when it is strictly
    closer (squared distance) to that centre, else 0.

    Args:
        g: sequence of scalars, positionally aligned with ``label``'s rows.
        label: DataFrame whose first column holds 0/1 labels (not modified).

    Returns:
        list of ints (0 or 1), one per element of ``g``.
    """
    label_ = label.copy()
    label_.reset_index(inplace=True, drop=True)
    labels = list(label_.iloc[:, 0])
    ind_1 = [i for i, j in enumerate(labels) if j == 1]
    # Set membership keeps this O(n); testing ``i in ind_1`` was O(n * k).
    pos = set(ind_1)
    mean_pos = np.mean([v for i, v in enumerate(g) if i in pos])
    mean_rest = (sum(g) - len(ind_1) * mean_pos) / (len(g) - len(ind_1))
    if abs(mean_pos) > abs(mean_rest):
        mean_1, mean_0 = mean_pos, mean_rest
    else:
        mean_1, mean_0 = mean_rest, mean_pos
    return [1 if (mean_1 - g[i]) ** 2 < (mean_0 - g[i]) ** 2 else 0
            for i in range(len(g))]

def NA(Y_fake):
    """Min-max rescale the magnitudes ``|Y_fake[i]|`` into roughly [0, 1].

    The 1e-16 offset in numerator and denominator avoids division by zero
    when every magnitude is equal (the result is then all ones).
    """
    magnitudes = list(map(lambda k: np.sqrt(Y_fake[k] ** 2), range(len(Y_fake))))
    smallest = min(magnitudes)
    largest = max(magnitudes)
    return (magnitudes - smallest + 1e-16) / (largest - smallest + 1e-16)

##Flipping
def flip(input,frac=0):
    """Flip a random fraction of the 0/1 labels in ``input``'s first column.

    Args:
        input: DataFrame whose first column holds 0/1 labels (not modified).
        frac: fraction of rows to flip; ``int(frac * len(input))`` rows are
            chosen with Python's ``random`` module.

    Returns:
        New DataFrame with a single ``'label'`` column, indexed like ``input``.
    """
    Y = list(input.iloc[:, 0])
    num = int(frac * len(Y))
    # Set membership keeps this O(n); the list lookup was O(n) per row.
    to_flip = set(random.sample(list(range(len(Y))), num))
    Y = [1 - v if i in to_flip else v for i, v in enumerate(Y)]
    res = pd.DataFrame(Y)
    res.index = input.index
    res.columns = ['label']
    return res

def reverse(input):
    """Replace every entry ``v < 0.5`` of ``input`` by ``1 - v`` in place; return ``input``."""
    idx = 0
    n = len(input)
    while idx < n:
        if input[idx] < 0.5:
            input[idx] = 1 - input[idx]
        idx += 1
    return input
def ratio(input1,input2):
    """Return the list ``input1[i] / (input2[i] + 1e-16)`` for each index of ``input1``."""
    eps = 1e-16
    return [input1[k] / (input2[k] + eps) for k in range(len(input1))]


##Train model
def _train_attempt(Y_Train, y1_hat, info=False):
    """Run one complete training attempt for ``train`` and return its 14-tuple.

    Builds fresh networks and regressors. Then, for each of
    ``training_parameters['n_epochs']`` epochs, it runs one critic step (with
    weight clipping), one decoder step, and the hand-written generator updates
    for the GAFM, vanilla (``_O``) and ``_m`` branches. Finally it refits a
    linear regression of each participant output on ``X_train1`` and records
    the correlation metrics.

    Any ``ValueError`` (e.g. from ``LinearRegression.fit`` or ``pearsonr``)
    propagates so that ``train`` can restart.

    NOTE(review): the critic, the decoder and the GAFM generator read the
    module-level ``Y_train``, while the vanilla generators use ``Y_Train``.
    They agree only when the caller passes ``Y_train`` — confirm intended.
    """
    loss_fun = torch.nn.MSELoss()
    logreg1 = linear_model.LinearRegression()
    logreg1_O = linear_model.LinearRegression()
    discriminator = Discriminator().to(device)
    generator_T = Generator_GAFM()
    generator_O = Generator_Vanilla()
    generator_m = Generator_Vanilla()
    decoder = Decoder().to(device)
    optimizer_D = optim.Adam(discriminator.parameters(),
                             lr=training_parameters['learning_rate_discriminator'])
    optimizer_De = optim.Adam(decoder.parameters(),
                              lr=training_parameters['learning_rate_decoder'])

    # Independent float copies, one per branch. The generators may apply their
    # steps with an in-place ``-=``, and a shared buffer let the GAFM step leak
    # into the vanilla branches and into the caller's array.
    y1_hat = np.array(y1_hat, dtype=float)
    y1_hat_O = y1_hat.copy()
    y1_hat_m = y1_hat.copy()

    # Loop-invariant inputs, built once instead of every epoch.
    Y_train_t = torch.tensor(Y_train.values).float()
    y_train_col = np.array(Y_train.iloc[:, 0])
    y_test_col = np.array(Y_test.iloc[:, 0])
    X_tr = pd.DataFrame(X_train1)
    X_te = pd.DataFrame(X_test1)
    n_epochs = training_parameters['n_epochs']
    gen_epochs = training_parameters['generator_epoch']

    # The ``errors_*`` lists hold Pearson correlations with the true labels.
    errors_p1, errors_p1_test, errors_train, errors_test = [], [], [], []
    errors_p1_O, errors_p1_test_O, errors_train_O, errors_test_O = [], [], [], []
    Loss_G, Loss_D, Loss_De = [], [], []

    # Defined up front so that ``n_epochs == 0`` returns Nones instead of
    # raising UnboundLocalError.
    gradient_1 = gradient_1_O = gradient_1_m = None
    cluster_s_dis = cluster_e_O = cluster_e_dis = None

    for i in range(1, n_epochs + 1):
        if info:
            print('-------------- Epoch', i, ' --------------')

        # ---- Critic step ----
        y_hat = torch.tensor(M_fun(pd.DataFrame(y1_hat)).values).float()
        # Clear the previous epoch's gradients. Without this they accumulated,
        # and every Adam step used their running sum.
        optimizer_D.zero_grad()
        loss_D = (-torch.mean(discriminator(Y_train_t))
                  + torch.mean(discriminator(y_hat.detach())))
        loss_D.backward()
        optimizer_D.step()
        # Clip every critic parameter to [-clip_value, clip_value].
        for p in discriminator.parameters():
            p.data.clamp_(-clip_value, clip_value)

        if i == 1:
            cluster_s_dis = discriminator(y_hat.detach()).detach()
        if i == n_epochs:
            cluster_e_O = y_hat.detach()
            cluster_e_dis = discriminator(y_hat.detach()).detach()
        Loss_D.append(loss_D.detach().numpy())

        # ---- Decoder step ----
        optimizer_De.zero_grad()
        y_de = decoder(y_hat.detach())
        loss_De = training_parameters['a'] * loss_fun(y_de, Y_train_t)
        loss_De.backward()
        optimizer_De.step()
        Loss_De.append(loss_De.detach().numpy())

        # ---- Generator updates (hand-written NumPy gradients) ----
        pa, gradient_1 = generator_T.train(y1_hat.T.reshape(1, -1), Y_train,
                                           discriminator, decoder, gen_epochs)
        pa_O, gradient_1_O = generator_O.train(y1_hat_O.T.reshape(1, -1), Y_Train,
                                               gen_epochs)
        pa_m, gradient_1_m = generator_m.train(y1_hat_m.T.reshape(1, -1), Y_Train,
                                               gen_epochs)

        y1_hat = pa[0]
        loss_G = loss_fun(discriminator(y_hat.detach()), discriminator(Y_train_t))
        Loss_G.append(loss_G.detach().numpy())
        y1_hat_O = pa_O[0]
        # NOTE(review): the commented-out design fed this branch through
        # KL_gradient_perturb_function_creator. Until that is restored, the
        # ``_m`` branch simply keeps its own generator iterate.
        y1_hat_m = pa_m[0]

        # ---- Participant refit: GAFM branch ----
        logreg1.fit(X_tr, y1_hat)
        y1_hat = logreg1.predict(X_tr)
        Y_predict1 = logreg1.predict(X_te)
        errors_p1.append(pearsonr(y_train_col, np.array(y1_hat))[0])
        errors_p1_test.append(pearsonr(y_test_col, np.array(Y_predict1))[0])

        # ---- Participant refit: vanilla branch ----
        logreg1_O.fit(X_tr, y1_hat_O)
        y1_hat_O = logreg1_O.predict(X_tr)
        Y_predict1_O = logreg1_O.predict(X_te)
        errors_p1_O.append(pearsonr(y_train_col, np.array(y1_hat_O))[0])
        errors_p1_test_O.append(pearsonr(y_test_col, np.array(Y_predict1_O))[0])

        # ---- Decoded outputs (GAFM) and activated outputs (vanilla) ----
        Y_train_hat_np = decoder(torch.tensor(pd.DataFrame(M_fun(y1_hat)).values).float())
        Y_test_hat_np = decoder(torch.tensor(pd.DataFrame(M_fun(Y_predict1)).values).float())
        errors_train.append(pearsonr(y_train_col, Y_train_hat_np.detach().numpy())[0])
        errors_test.append(pearsonr(y_test_col, Y_test_hat_np.detach().numpy())[0])
        errors_train_O.append(pearsonr(y_train_col, M_fun(y1_hat_O))[0])
        errors_test_O.append(pearsonr(y_test_col, M_fun(Y_predict1_O))[0])

    return (errors_p1, errors_p1_test, errors_train, errors_test,
            gradient_1, gradient_1_m, gradient_1_O,
            cluster_s_dis, cluster_e_O, cluster_e_dis,
            errors_p1_O, errors_p1_test_O, errors_train_O, errors_test_O)


def train(Y_Train,y1_hat,threshold=0.0001,frac=0,info=False,max_retries=100):
    """Train GAFM alongside a vanilla baseline, restarting on ``ValueError``.

    Args:
        Y_Train: label DataFrame used by the vanilla generators.
        y1_hat: initial participant output (array-like); not modified.
        threshold: unused; kept for call compatibility.
        frac: unused (the noise step that consumed it is commented out); kept
            for call compatibility.
        info: when true, print a header line per epoch.
        max_retries: number of restarts allowed after a ``ValueError`` before
            it is re-raised. Every restart begins from the original ``y1_hat``
            with freshly initialised networks.

    Returns:
        14-tuple ``(errors_p1, errors_p1_test, errors_train, errors_test,
        gradient_1, gradient_1_m, gradient_1_O, cluster_s_dis, cluster_e_O,
        cluster_e_dis, errors_p1_O, errors_p1_test_O, errors_train_O,
        errors_test_O)``. It contains:

        - per-epoch correlation lists;
        - the last epoch's summed generator steps for each branch;
        - the critic's scores of ``M_fun(y1_hat)`` at the first and last
          epoch, plus ``M_fun(y1_hat)`` itself at the last epoch.
    """
    y1_hat_init = np.array(y1_hat, dtype=float)
    attempt = 0
    while True:
        try:
            return _train_attempt(Y_Train, y1_hat_init, info=info)
        except ValueError:
            # Restart from the initial value rather than from the iterate that
            # just failed, and bound the number of restarts (the old recursive
            # retry could only stop at RecursionError).
            attempt += 1
            if attempt > max_retries:
                raise

def _report_spread(name, values):
    """Print ``name``, the mean of ``values``, and their min/max minus that mean."""
    mean = np.mean(values)
    print(name, mean, np.min(values) - mean, np.max(values) - mean)


def loop_train(Y_Train,Y_test,Y1_hat,N,frac=0,threshold=0.0001,info=False):
    """Run ``train`` ``N`` times and summarise the final-epoch metrics.

    For each of eight metric families, the last-epoch value of every run is
    collected and printed as ``mean, min - mean, max - mean``. The families
    are decoded train/test and participant ``p1`` train/test, for both the
    GAFM branch and the vanilla ``_O`` branch.

    Args:
        Y_Train: labels forwarded to ``train``.
        Y_test: unused here (``train`` reads the module-level ``Y_test``).
        Y1_hat: initial participant output forwarded to ``train``.
        N: number of independent runs.
        frac, threshold, info: forwarded to ``train``.

    Returns:
        ``(errors_p1_loop, errors_p1_test_loop, errors_train_loop,
        errors_test_loop, plot_errors_train, plot_errors_test, plot_p1,
        plot_p1_test)``. The first four hold the per-run final values for the
        GAFM branch; the ``plot_*`` entries are the full per-epoch curves of
        the selected run.
    """
    best_AUC = float('inf')
    plot_errors_train = plot_errors_test = plot_p1 = plot_p1_test = None
    errors_train_loop, errors_test_loop = [], []
    errors_p1_loop, errors_p1_test_loop = [], []
    errors_train_loop_O, errors_test_loop_O = [], []
    errors_p1_loop_O, errors_p1_test_loop_O = [], []
    gamma = training_parameters['a']
    print(f'-------------------------Gammma = {gamma}-------------------------')
    for _ in trange(N):
        (errors_p1, errors_p1_test, errors_train, errors_test,
         _, _, _, _, _, _,
         errors_p1_O, errors_p1_test_O, errors_train_O, errors_test_O) = train(
            Y_Train=Y_Train, y1_hat=Y1_hat, frac=frac, threshold=threshold, info=info)

        errors_train_loop.append(errors_train[-1])
        errors_test_loop.append(errors_test[-1])
        errors_train_loop_O.append(errors_train_O[-1])
        errors_test_loop_O.append(errors_test_O[-1])
        errors_p1_loop.append(errors_p1[-1])
        errors_p1_test_loop.append(errors_p1_test[-1])
        errors_p1_loop_O.append(errors_p1_O[-1])
        errors_p1_test_loop_O.append(errors_p1_test_O[-1])

        # Keep the curves of the run with the LOWEST final train value.
        # NOTE(review): these values are Pearson correlations, and "lowest"
        # looks inherited from an earlier MSE metric — confirm the direction.
        # NaN scores rank last. The first run is kept as a fallback, so the
        # plot_* results are defined even when every score is NaN.
        score = errors_train[-1]
        if np.isnan(score):
            score = float('inf')
        if plot_errors_train is None or score < best_AUC:
            plot_errors_train = errors_train
            plot_errors_test = errors_test
            plot_p1 = errors_p1
            plot_p1_test = errors_p1_test
            best_AUC = score

    _report_spread('plot_errors_train', errors_train_loop)
    _report_spread('plot_errors_test', errors_test_loop)
    _report_spread('plot_p1', errors_p1_loop)
    _report_spread('plot_p1_test', errors_p1_test_loop)

    _report_spread('plot_errors_train_O', errors_train_loop_O)
    _report_spread('plot_errors_test_O', errors_test_loop_O)
    _report_spread('plot_p1_O', errors_p1_loop_O)
    _report_spread('plot_p1_test_O', errors_p1_test_loop_O)

    return errors_p1_loop,errors_p1_test_loop,errors_train_loop,errors_test_loop,plot_errors_train,plot_errors_test,plot_p1,plot_p1_test








